Abstract
Background: We have run the simplified Naomi model using a range of inference methods (TMB, aghq, and tmbstan).
Task: In this report, we compare the accuracy of the posterior distributions obtained from these inference methods.
# Load the saved output of each inference method from the depends/ directory.
# The rest of the report reads the `time` element and the posterior samples
# stored in each of these objects.
tmb <- readRDS("depends/tmb.rds")
aghq <- readRDS("depends/aghq.rds")
tmbstan <- readRDS("depends/tmbstan.rds")
All of the possible parameter names are as follows:
# names(obj$env$par) can contain the same name more than once (to_ks_df_2()
# tabulates the names to get each parameter's length), so deduplicate them
# before printing
unique(names(tmb$fit$obj$env$par))
## [1] "beta_rho" "beta_alpha" "beta_lambda" "beta_anc_rho" "beta_anc_alpha"
## [6] "logit_phi_rho_x" "log_sigma_rho_x" "logit_phi_rho_xs" "log_sigma_rho_xs" "logit_phi_rho_a"
## [11] "log_sigma_rho_a" "logit_phi_rho_as" "log_sigma_rho_as" "log_sigma_rho_xa" "u_rho_x"
## [16] "us_rho_x" "u_rho_xs" "us_rho_xs" "u_rho_a" "u_rho_as"
## [21] "logit_phi_alpha_x" "log_sigma_alpha_x" "logit_phi_alpha_xs" "log_sigma_alpha_xs" "logit_phi_alpha_a"
## [26] "log_sigma_alpha_a" "logit_phi_alpha_as" "log_sigma_alpha_as" "log_sigma_alpha_xa" "u_alpha_x"
## [31] "us_alpha_x" "u_alpha_xs" "us_alpha_xs" "u_alpha_a" "u_alpha_as"
## [36] "u_alpha_xa" "OmegaT_raw" "log_betaT" "log_sigma_lambda_x" "ui_lambda_x"
## [41] "log_sigma_ancrho_x" "log_sigma_ancalpha_x" "ui_anc_rho_x" "ui_anc_alpha_x" "log_sigma_or_gamma"
## [46] "log_or_gamma"
# Side-by-side table of the `time` element stored with each method's results
data.frame(TMB = tmb$time, aghq = aghq$time, tmbstan = tmbstan$time)
#' Compare the posterior samples of one parameter across inference methods
#'
#' Draws one density-scaled histogram per method (TMB, aghq, tmbstan), with
#' the methods arranged in facet rows. Reads the global objects `tmb`, `aghq`
#' and `tmbstan`.
#'
#' @param par Name of the parameter to plot, e.g. "beta_anc_rho". All samples
#'   stored under this name are pooled into a single histogram per method.
#' @return A ggplot object.
histogram_plot <- function(par) {
  # as.numeric()/unlist() flatten whatever is stored under `par` into one
  # plain vector of samples per method
  df_compare <- rbind(
    data.frame(method = "TMB", samples = as.numeric(tmb$fit$sample[[par]])),
    data.frame(method = "aghq", samples = as.numeric(aghq$quad$sample[[par]])),
    data.frame(
      method = "tmbstan",
      samples = as.numeric(unlist(rstan::extract(tmbstan$mcmc, pars = par)))
    )
  )

  # A per-method sample count used to be computed here and thrown away; it had
  # no effect on the plot, so it has been removed.

  ggplot(df_compare, aes(x = samples, fill = method, col = method)) +
    geom_histogram(
      aes(y = after_stat(density)),
      alpha = 0.5,
      position = "identity"
    ) +
    facet_grid(method ~ .) +
    scale_color_manual(values = multi.utils::cbpalette()) +
    scale_fill_manual(values = multi.utils::cbpalette()) +
    labs(title = paste0(par), x = paste0(par), y = "Density", fill = "Method") +
    theme_minimal() +
    # The facet strips already identify the method, so the legend is redundant
    theme(legend.position = "none")
}
# Example: compare the sampled distributions of beta_anc_rho across methods
histogram_plot("beta_anc_rho")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
#' KS statistics against tmbstan for every element of one parameter
#'
#' For each element of `par`, computes `inf.utils::ks_test(...)$D` between the
#' TMB samples and the tmbstan samples, and between the aghq samples and the
#' tmbstan samples. Reads the global objects `tmb`, `aghq` and `tmbstan`.
#'
#' @param par Exact name of a single parameter, e.g. "u_rho_x".
#' @return A data frame with columns `method` ("TMB" or "aghq"), `ks`, `par`
#'   and `index` (element position within the parameter), with one row per
#'   method and element.
to_ks_df <- function(par) {
  # Transpose so that columns index parameter elements, matching the tmbstan
  # samples (assumes the stored samples are elements x draws -- TODO confirm)
  samples_tmb <- t(tmb$fit$sample[[par]])
  samples_aghq <- t(aghq$quad$sample[[par]])
  samples_tmbstan <- as.data.frame(
    rstan::extract(tmbstan$mcmc, pars = par)[[par]]
  )

  # seq_len() rather than 1:n, which would count down through c(1, 0) if
  # there were no columns
  index <- seq_len(ncol(samples_tmbstan))

  # D statistic of each column of `samples` against the same tmbstan column
  ks_against_tmbstan <- function(samples) {
    vapply(
      index,
      function(i) inf.utils::ks_test(samples[, i], samples_tmbstan[, i])$D,
      numeric(1)
    )
  }

  rbind(
    data.frame(
      method = "TMB",
      ks = ks_against_tmbstan(samples_tmb),
      par = par,
      index = index
    ),
    data.frame(
      method = "aghq",
      ks = ks_against_tmbstan(samples_aghq),
      par = par,
      index = index
    )
  )
}
#' KS statistics against tmbstan for all parameters sharing a name prefix
#'
#' Selects every parameter whose name starts with `par`, lays out the samples
#' of each method with one labelled column per parameter element, and computes
#' `inf.utils::ks_test(...)$D` column by column for TMB vs tmbstan and for
#' aghq vs tmbstan. Reads the global objects `tmb`, `aghq` and `tmbstan`.
#'
#' @param par Prefix of the parameter names to include, e.g. "beta". It is
#'   passed to `stringr::str_starts()` as the pattern.
#' @return A data frame with columns `method` ("TMB" or "aghq"), `ks`, `par`
#'   (the column label: the parameter name, with "[k]" appended for parameters
#'   of length greater than one) and `index` (column position), with one row
#'   per method and column.
to_ks_df_2 <- function(par) {
  all_par_names <- names(tmb$fit$obj$env$par)
  par_names <- all_par_names[stringr::str_starts(all_par_names, par)]
  unique_par_names <- unique(par_names)

  if (length(unique_par_names) == 0) {
    stop("No parameter names start with '", par, "'", call. = FALSE)
  }

  # The number of times a name occurs is used as that parameter's length
  par_lengths <- table(par_names)

  # Transpose so that columns index parameter elements (assumes the stored
  # samples are elements x draws -- TODO confirm)
  to_columns <- function(x) as.data.frame(t(x))
  samples_tmb <- lapply(tmb$fit$sample[unique_par_names], to_columns)
  samples_aghq <- lapply(aghq$quad$sample[unique_par_names], to_columns)
  samples_tmbstan <- lapply(
    rstan::extract(tmbstan$mcmc, pars = unique_par_names),
    as.data.frame
  )

  # Give all three sample sets identical column labels. The loop variable is
  # deliberately not called `par`, so the function argument is not overwritten.
  for (par_name in unique_par_names) {
    par_length <- par_lengths[[par_name]]
    par_colnames <- if (par_length > 1) {
      paste0(par_name, "[", seq_len(par_length), "]")
    } else {
      par_name
    }
    colnames(samples_tmb[[par_name]]) <- par_colnames
    colnames(samples_aghq[[par_name]]) <- par_colnames
    colnames(samples_tmbstan[[par_name]]) <- par_colnames
  }

  samples_tmb <- dplyr::bind_cols(samples_tmb)
  samples_aghq <- dplyr::bind_cols(samples_aghq)
  samples_tmbstan <- dplyr::bind_cols(samples_tmbstan)

  col_names <- names(samples_tmbstan)

  # Pair columns by label rather than by position, so the comparison does not
  # depend on the three sample lists being stored in the same order
  ks_against_tmbstan <- function(samples) {
    vapply(
      col_names,
      function(col) {
        inf.utils::ks_test(samples[[col]], samples_tmbstan[[col]])$D
      },
      numeric(1),
      USE.NAMES = FALSE
    )
  }

  index <- seq_along(col_names)
  rbind(
    data.frame(
      method = "TMB",
      ks = ks_against_tmbstan(samples_tmb),
      par = col_names,
      index = index
    ),
    data.frame(
      method = "aghq",
      ks = ks_against_tmbstan(samples_aghq),
      par = col_names,
      index = index
    )
  )
}
#' Plot the KS statistics of TMB and aghq against tmbstan
#'
#' @param ks_df Data frame as returned by `to_ks_df()` or `to_ks_df_2()`, with
#'   columns `method` ("TMB" or "aghq"), `ks`, `par` and `index`.
#' @return A cowplot grid with two panels: a scatterplot of the TMB statistic
#'   (x) against the aghq statistic (y), and a boxplot of their difference.
plot_ks_df <- function(ks_df) {
  wide_ks_df <- pivot_wider(
    ks_df,
    names_from = "method",
    values_from = "ks"
  ) %>%
    mutate(ks_diff = TMB - aghq)

  mean_ks_diff <- mean(wide_ks_df$ks_diff)

  boxplot_panel <- ggplot(wide_ks_df, aes(x = ks_diff)) +
    geom_boxplot(width = 0.5) +
    coord_flip() +
    labs(
      title = paste0("Mean KS difference is ", mean_ks_diff),
      subtitle = ">0 then TMB more different to tmbstan, <0 then aghq more different",
      x = "KS(TMB, tmbstan) - KS(aghq, tmbstan)"
    ) +
    theme_minimal()

  # `ks_df$par` has one entry per row, so collapse it to a single label;
  # otherwise the title would be a vector with one string per row
  par_label <- paste(unique(ks_df$par), collapse = ", ")
  scatter_title <- paste(
    strwrap(
      paste0("KS tests for ", par_label, " of length ", max(ks_df$index)),
      width = 60
    ),
    collapse = "\n"
  )

  # Keep the default [0, 0.5] window, but widen it to the largest observed
  # value so that no point falls outside the limits and is dropped
  axis_max <- max(0.5, wide_ks_df$TMB, wide_ks_df$aghq, na.rm = TRUE)

  scatterplot <- ggplot(wide_ks_df, aes(x = TMB, y = aghq)) +
    geom_point(alpha = 0.5) +
    xlim(0, axis_max) +
    ylim(0, axis_max) +
    geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
    labs(
      title = scatter_title,
      subtitle = "Values along y = x have similar KS",
      # Labels follow the aesthetics: TMB is on x and aghq is on y
      x = "KS(TMB, tmbstan)", y = "KS(aghq, tmbstan)"
    ) +
    theme_minimal()

  cowplot::plot_grid(
    scatterplot,
    boxplot_panel,
    ncol = 2,
    rel_widths = c(1.3, 1)
  )
}
# Compute and plot the KS comparison for a single, exactly named parameter
ks_helper <- function(par) {
  plot_ks_df(to_ks_df(par))
}

# Compute and plot the KS comparison for every parameter whose name starts
# with `par`
ks_helper_2 <- function(par) {
  plot_ks_df(to_ks_df_2(par))
}
# Parameters grouped by name prefix: each call plots every parameter whose
# name starts with the given string
ks_helper_2("beta")
ks_helper_2("logit")
## Warning: Removed 8 rows containing missing values (`geom_point()`).
ks_helper_2("log_sigma")
## Warning: Removed 14 rows containing missing values (`geom_point()`).
# Parameters selected one at a time by exact name: each call plots one point
# per element of that parameter
ks_helper("u_rho_x")
ks_helper("u_rho_xs")
ks_helper("us_rho_x")
ks_helper("us_rho_xs")
ks_helper("u_rho_a")
ks_helper("u_rho_as")
ks_helper("u_alpha_x")
ks_helper("u_alpha_xs")
ks_helper("us_alpha_x")
ks_helper("us_alpha_xs")
ks_helper("u_alpha_a")
ks_helper("u_alpha_as")
ks_helper("u_alpha_xa")
ks_helper("ui_anc_rho_x")
ks_helper("ui_anc_alpha_x")
ks_helper("log_or_gamma")
# One row per parameter: the KS statistic of each method against tmbstan,
# averaged over the elements of that parameter
ks_summary_table <- unique(names(tmb$fit$obj$env$par)) %>%
  lapply(function(par_name) {
    ks_by_method <- group_by(to_ks_df(par_name), method)
    # `par` here is the column created by to_ks_df(); it is constant within
    # each call, so its first value labels the summary row
    summarise(ks_by_method, ks = mean(ks), par = par[1])
  }) %>%
  bind_rows() %>%
  pivot_wider(names_from = "method", values_from = "ks") %>%
  rename(
    "Parameter" = "par",
    "KS(aghq, tmbstan)" = "aghq",
    "KS(TMB, tmbstan)" = "TMB"
  )
# Render the summary as a gt table, formatting the KS columns to 3 decimals
gt::fmt_number(
  gt::gt(ks_summary_table),
  columns = starts_with("KS"),
  decimals = 3
)
| Parameter | KS(aghq, tmbstan) | KS(TMB, tmbstan) |
|---|---|---|
| beta_rho | 0.131 | 0.130 |
| beta_alpha | 0.156 | 0.144 |
| beta_lambda | 0.047 | 0.060 |
| beta_anc_rho | 0.119 | 0.118 |
| beta_anc_alpha | 0.068 | 0.044 |
| logit_phi_rho_x | 0.544 | 0.656 |
| log_sigma_rho_x | 0.458 | 0.734 |
| logit_phi_rho_xs | 0.230 | 0.685 |
| log_sigma_rho_xs | 0.898 | 0.868 |
| logit_phi_rho_a | 0.827 | 0.626 |
| log_sigma_rho_a | 0.499 | 0.586 |
| logit_phi_rho_as | 0.825 | 0.506 |
| log_sigma_rho_as | 0.295 | 0.507 |
| log_sigma_rho_xa | 0.440 | 0.690 |
| u_rho_x | 0.127 | 0.127 |
| us_rho_x | 0.119 | 0.120 |
| u_rho_xs | 0.231 | 0.225 |
| us_rho_xs | 0.097 | 0.094 |
| u_rho_a | 0.080 | 0.081 |
| u_rho_as | 0.167 | 0.173 |
| logit_phi_alpha_x | 0.430 | 0.653 |
| log_sigma_alpha_x | 0.833 | 0.632 |
| logit_phi_alpha_xs | 0.372 | 0.666 |
| log_sigma_alpha_xs | 0.688 | 0.649 |
| logit_phi_alpha_a | 0.834 | 0.501 |
| log_sigma_alpha_a | 0.367 | 0.501 |
| logit_phi_alpha_as | 0.742 | 0.514 |
| log_sigma_alpha_as | 0.241 | 0.539 |
| log_sigma_alpha_xa | 0.708 | 0.662 |
| u_alpha_x | 0.154 | 0.146 |
| us_alpha_x | 0.104 | 0.097 |
| u_alpha_xs | 0.130 | 0.124 |
| us_alpha_xs | 0.129 | 0.132 |
| u_alpha_a | 0.185 | 0.189 |
| u_alpha_as | 0.088 | 0.075 |
| u_alpha_xa | 0.099 | 0.097 |
| OmegaT_raw | 0.189 | 0.502 |
| log_betaT | 0.180 | 0.687 |
| log_sigma_lambda_x | 0.618 | 0.810 |
| ui_lambda_x | 0.269 | 0.267 |
| log_sigma_ancrho_x | 0.707 | 0.530 |
| log_sigma_ancalpha_x | 0.706 | 0.677 |
| ui_anc_rho_x | 0.078 | 0.081 |
| ui_anc_alpha_x | 0.128 | 0.127 |
| log_sigma_or_gamma | 0.489 | 0.574 |
| log_or_gamma | 0.098 | 0.099 |
# Mean KS per parameter: TMB on x against aghq on y, with a dashed y = x
# reference line
ggplot(
  data = ks_summary_table,
  mapping = aes(x = `KS(TMB, tmbstan)`, y = `KS(aghq, tmbstan)`)
) +
  geom_point(alpha = 0.5) +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  lims(x = c(0, 1), y = c(0, 1)) +
  theme_minimal()
#' To write!
#' To write!